In [1]:
%matplotlib inline
from matplotlib import pyplot as plt
import torch
import numpy as np
In [2]:
from torchvision import models
# Loading vgg19 model and extracting the feature part
vgg = models.vgg19(pretrained = True).features
#Freezing the parameters in vgg
for param in vgg.parameters():
    param.requires_grad_(False)
In [3]:
#move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vgg.to(device)
Out[3]:
Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (20): ReLU(inplace=True)
  (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (22): ReLU(inplace=True)
  (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (24): ReLU(inplace=True)
  (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (26): ReLU(inplace=True)
  (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (29): ReLU(inplace=True)
  (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (31): ReLU(inplace=True)
  (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (33): ReLU(inplace=True)
  (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (35): ReLU(inplace=True)
  (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
In [4]:
#download the MS COCO dataset for training
!wget http://images.cocodataset.org/zips/test2017.zip
!mkdir './dataset'
!unzip -q ./test2017.zip -d './dataset'
--2020-11-11 02:08:24--  http://images.cocodataset.org/zips/test2017.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 52.216.78.60
Connecting to images.cocodataset.org (images.cocodataset.org)|52.216.78.60|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6646970404 (6.2G) [application/zip]
Saving to: ‘test2017.zip’

test2017.zip        100%[===================>]   6.19G  37.1MB/s    in 2m 4s   

2020-11-11 02:10:28 (51.3 MB/s) - ‘test2017.zip’ saved [6646970404/6646970404]

In [4]:
from torchvision import datasets
import torchvision.transforms as transforms
#Transform train images
batch_size = 4
num_workers = 0
train_transform = transforms.Compose([
                  transforms.Resize((264, 264)),
                  transforms.RandomCrop(256),
                  transforms.ToTensor(),
                  transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                       std = [0.229, 0.224, 0.225])
])
train_data = datasets.ImageFolder('./dataset', transform = train_transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size)
In [31]:
from PIL import Image
import torchvision.transforms as transforms
#Load image
def load_image(img_path):
    image = Image.open(img_path).convert('RGB')
    #print(image.size)
    image_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean = [0.485, 0.456, 0.406],
                             std = [0.229, 0.224, 0.225])
    ])
    
    image = image_transform(image)[:3, :, :].unsqueeze(0)
    #print(image.shape)
    
    
    return image
    
In [6]:
#load style image
style_image = load_image('great_wave.jpg')
style_image = style_image.to(device)
In [7]:
# Un-normalize image tensors
def denormalize(tensor):
    image = tensor.to('cpu').clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1, 2, 0)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    
    return image
In [8]:
#for name, layer in vgg._modules.items():
    #print(name)
    #print(layer)
In [8]:
import torch.nn as nn
class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        
    def forward(self, x):  
        layers = {'3': 'relu1_2',
                  '8': 'relu2_2', 
                  '17': 'relu3_4', 
                  '22': 'relu4_2',
                  '26': 'relu4_4',  ## content representation
                  '35': 'relu5_4'}
        features = {}
        for name, layer in vgg._modules.items():
            x = layer(x)
            if name in layers:
                features[layers[name]] = x
                
        return features      
        
In [9]:
import torch.nn.functional as F
class transformer(nn.Module):
    def __init__(self):
        super(transformer, self).__init__()
        self.conv_block = nn.Sequential(
        conv(3, 32, 9, 1),
        nn.ReLU(),
        conv(32, 64, 3, 2),
        nn.ReLU(),
        conv(64, 128, 3, 2),
        nn.ReLU()
        )
        
        self.residual_block = nn.Sequential(
        ResidualBlock(128),
        ResidualBlock(128),
        ResidualBlock(128),
        ResidualBlock(128),
        ResidualBlock(128)
        )
        #add relu activation
        self.deconv_block = nn.Sequential(
        deconv(128, 64, 3, 2, 1),
        nn.ReLU(),
        deconv(64, 32, 3, 2, 1),
        nn.ReLU(),
        conv(32, 3, 9, 1, normalize = False)
        
        )
        
    def forward(self, x):
        x = self.conv_block(x)
        x = self.residual_block(x)
        x = self.deconv_block(x)
        
        return x
In [10]:
class conv(nn.Module):
    
    def __init__(self,in_channels, out_channels, kernel_size, stride, normalize = True):
        super(conv, self).__init__()
    
        self.reflection_pad = nn.ReflectionPad2d(kernel_size//2)
        self.conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
        self.norm = nn.InstanceNorm2d(out_channels, affine = True) if normalize else None
    
    def forward(self, x):
        x = self.reflection_pad(x)
        x = self.conv_layer(x)
        if self.norm is not None:
            x = self.norm(x)
          
        return x
In [11]:
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv(channels, channels, 3, 1)
        self.conv2 = conv(channels, channels, 3, 1)
        
    def forward(self, x):
        in_x = x
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x) +in_x
        
        return x
In [12]:
class deconv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, output_padding, normalize = True):
        super(deconv, self).__init__()
        #self.reflection_pad = nn.ReflectionPad2d(kernel_size//2)
        padding_size = kernel_size//2
        self.deconv_layer = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride,padding_size, output_padding)
        self.norm = nn.InstanceNorm2d(out_channels, affine = True) if normalize else None
        
    def forward(self, x):
        #x = self.reflection_pad(x)
        x = self.deconv_layer(x)
        if self.norm is not None:
            x = self.norm(x)
            
        return x
        
        
In [13]:
#gram matrix
def gram_matrix(x):
    b, d, h, w = x.size()
    tensor = x.view(b, d, h*w)
    tensor_transpose = tensor.transpose(1, 2)
    gram = torch.bmm(tensor,tensor_transpose)/(d*h*w)
    
    return gram
In [14]:
vgg_net = VGG()
transformer_net = transformer().to(device)
In [15]:
#compute style gram matrix
style_features = vgg_net(style_image)
style_gram = { layer:gram_matrix(style_features[layer]) for layer in style_features}
In [16]:
import os
import torch.optim as optim
from PIL import Image

learning_rate = 0.001
optimizer = optim.Adam(transformer_net.parameters(), lr = learning_rate)
criterion = nn.MSELoss().to(device)
epochs = 1
content_weight = 1
style_weight = 12
checkpoint_path = 'checkpoints'
images_path = 'train_results'

os.makedirs(checkpoint_path, exist_ok = True)
os.makedirs(images_path, exist_ok = True)

for epoch in range(epochs):
    for batch, (images,_) in enumerate(train_loader):
        
        batch_size = images.shape[0]
        #print(batch)
        #print(images.shape)
        images = images.to(device)
        optimizer.zero_grad()
        output_images = transformer_net(images)
        #print(output_images.shape)
        
        # get features
        features = vgg_net(images)
        output_features = vgg_net(output_images)
        content_loss = content_weight*criterion( output_features['relu2_2'], features['relu2_2'])
        
        style_loss = 0
        for layer_name, layer in output_features.items():
            gram_f = gram_matrix(layer)
            style_loss += criterion(gram_f, style_gram[layer_name][:batch_size])
            style_loss *=  style_weight
            
        total_loss = content_loss + style_loss
        
        total_loss.backward()
        optimizer.step()
            
        if batch%400 == 399 or batch == len(train_loader)-1:
            print('Batch {}/{}'.format(batch+1, len(train_loader)))
            print('Total loss {}'.format(total_loss.item()))
            #get an image
            fig, ax = plt.subplots(1, 2)
            input_image = images[0].clone().detach()
            input_image = denormalize(input_image)
            ax[0].imshow(input_image)
            ax[0].set_title('Content image')
            ax[0].set_xticks([])
            ax[0].set_yticks([])
            
            output_image = output_images[0].clone().detach()
            output_image = denormalize(output_image)
            ax[1].imshow(output_image)
            ax[1].set_title('Stylized image')
            ax[1].set_xticks([])
            ax[1].set_yticks([])
            plt.savefig(os.path.join(images_path,'result_{}.png'.format(batch+1)))
            plt.show()
            
            torch.save(transformer_net.state_dict(), os.path.join(checkpoint_path,'model_{}'.format(batch+1)))
             
            
        
/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/loss.py:431: UserWarning: Using a target size (torch.Size([1, 64, 64])) that is different to the input size (torch.Size([4, 64, 64])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/loss.py:431: UserWarning: Using a target size (torch.Size([1, 128, 128])) that is different to the input size (torch.Size([4, 128, 128])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/loss.py:431: UserWarning: Using a target size (torch.Size([1, 256, 256])) that is different to the input size (torch.Size([4, 256, 256])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/loss.py:431: UserWarning: Using a target size (torch.Size([1, 512, 512])) that is different to the input size (torch.Size([4, 512, 512])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 400/10168
Total loss 44.687782287597656
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 800/10168
Total loss 22.06247901916504
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 1200/10168
Total loss 16.27564239501953
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 1600/10168
Total loss 15.917465209960938
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 2000/10168
Total loss 14.430047988891602
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 2400/10168
Total loss 14.091354370117188
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 2800/10168
Total loss 13.166362762451172
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 3200/10168
Total loss 11.666818618774414
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 3600/10168
Total loss 11.936600685119629
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 4000/10168
Total loss 11.357491493225098
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 4400/10168
Total loss 11.756816864013672
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 4800/10168
Total loss 14.453180313110352
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 5200/10168
Total loss 11.594239234924316
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 5600/10168
Total loss 12.628231048583984
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 6000/10168
Total loss 13.654515266418457
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 6400/10168
Total loss 9.585670471191406
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 6800/10168
Total loss 9.375436782836914
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 7200/10168
Total loss 9.322410583496094
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 7600/10168
Total loss 8.270928382873535
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 8000/10168
Total loss 9.798877716064453
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 8400/10168
Total loss 8.695219039916992
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 8800/10168
Total loss 9.502508163452148
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 9200/10168
Total loss 10.261441230773926
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 9600/10168
Total loss 8.499666213989258
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 10000/10168
Total loss 10.684791564941406
/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/loss.py:431: UserWarning: Using a target size (torch.Size([1, 64, 64])) that is different to the input size (torch.Size([2, 64, 64])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/loss.py:431: UserWarning: Using a target size (torch.Size([1, 128, 128])) that is different to the input size (torch.Size([2, 128, 128])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/loss.py:431: UserWarning: Using a target size (torch.Size([1, 256, 256])) that is different to the input size (torch.Size([2, 256, 256])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/loss.py:431: UserWarning: Using a target size (torch.Size([1, 512, 512])) that is different to the input size (torch.Size([2, 512, 512])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Batch 10168/10168
Total loss 7.845559120178223
In [22]:
#save final model
torch.save(transformer_net.state_dict(),'final_model')
In [33]:
#Test the model
import glob
test_results_path = 'test_results'
os.makedirs(test_results_path, exist_ok = True)
transformer_model = transformer().to(device)
transformer_model.load_state_dict(torch.load('final_model'))
transformer_model.eval()

sample_images_path = list(glob.glob('test_images/*'))
#print(sample_images_path)
for index, sample_image_path in enumerate(sample_images_path):
    #print(sample_image_path)
    test_image = load_image(sample_image_path)
    test_image = test_image.to(device)
    test_output = transformer_model(test_image)
    test_output = test_output.clone().detach()
    stylized_image = denormalize(test_output)
    
    content_image = denormalize(test_image.clone().detach())
    
    fig, ax = plt.subplots(1, 2, figsize = (10, 20))
    ax[0].imshow(content_image)
    ax[0].set_title('Content_image')
    ax[0].set_xticks([])
    ax[0].set_yticks([])
    
    ax[1].imshow(stylized_image)
    ax[1].set_title('Stylized_image')
    ax[1].set_xticks([])
    ax[1].set_yticks([])
    plt.show()
    plt.savefig(os.path.join(test_results_path,'output_{}'.format(index)))
    
    
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 432x288 with 0 Axes>
<Figure size 432x288 with 0 Axes>
In [ ]:
    
In [ ]: